import os
import sys
import gzip
from collections import defaultdict
from Bio import SeqIO
from pylab import *


# Command-line arguments: the target name (used below to build the names of
# the input and output files) and the ID of the transcript to plot.
# Exactly two arguments are required; any other number makes the tuple
# unpacking raise ValueError.
target, transcript = sys.argv[1:]


# Header that every PSL file read by parse_alignments() must start with.
# The column titles are separated by tab characters; keep them intact.
PSL_HEADER = """\
psLayout version 3

match	mis- 	rep. 	N's	Q gap	Q gap	T gap	T gap	strand	Q        	Q   	Q    	Q  	T        	T   	T    	T  	block	blockSizes 	qStarts	 tStarts
     	match	match	   	count	bases	count	bases	      	name     	size	start	end	name     	size	start	end	count
---------------------------------------------------------------------------------------------------------------------------------------------------------------
"""


def parse_alignments(path):
    """Yield the target name and '-' strand target end of each alignment pair.

    The gzip-compressed PSL file at `path` must start with the five-line
    header stored in PSL_HEADER. The remaining lines are read two at a
    time: a '+' strand alignment followed by a '-' strand alignment with
    the same query name, target name, and target size.

    Parameters:
        path: name of the gzip-compressed PSL file to read.

    Yields:
        (tName, tEnd) tuples, where tName is the target name shared by the
        two alignments of a pair and tEnd is the integer target end
        coordinate of the '-' strand alignment.

    Raises:
        ValueError: if the file does not start with the expected header,
            ends in the middle of a pair, or contains a pair whose two
            lines are inconsistent with each other.
    """
    print("Reading", path)
    # The with statement closes the file when the generator is exhausted,
    # when a check below fails, and when the generator is discarded early.
    with gzip.open(path, "rt") as handle:
        # next() with a default avoids StopIteration escaping from this
        # generator (which Python reports as RuntimeError) on short files.
        header = "".join(next(handle, "") for _ in range(5))
        if header != PSL_HEADER:
            raise ValueError(
                "%s does not start with the expected PSL header" % path)
        for line1 in handle:
            line2 = next(handle, None)
            if line2 is None:
                raise ValueError(
                    "%s ends with an unpaired alignment line" % path)
            words1 = line1.split()
            words2 = line2.split()
            # Columns: 8 = strand, 9 = query name, 13 = target name,
            # 14 = target size, 16 = target end.
            if words1[8] != '+' or words2[8] != '-':
                raise ValueError(
                    "%s: expected a '+' strand line followed by a '-' "
                    "strand line, found '%s' and '%s'"
                    % (path, words1[8], words2[8]))
            if words1[9] != words2[9]:
                raise ValueError(
                    "%s: paired lines have different query names "
                    "'%s' and '%s'" % (path, words1[9], words2[9]))
            tName = words1[13]
            if tName != words2[13]:
                raise ValueError(
                    "%s: paired lines have different target names "
                    "'%s' and '%s'" % (path, tName, words2[13]))
            if int(words1[14]) != int(words2[14]):
                raise ValueError(
                    "%s: paired lines have different target sizes "
                    "%s and %s" % (path, words1[14], words2[14]))
            tEnd = int(words2[16])
            yield (tName, tEnd)
    

# Look up the requested transcript in the FASTA file of this target and
# extract the gene name from its description line.
directory = "/osc-fs_home/mdehoon/Data/CASPARs/Filters"
path = os.path.join(directory, "%s.fa" % target)
print("Reading", path)
with open(path) as handle:
    # first record whose ID equals the requested transcript, or None
    record = next(
        (entry for entry in SeqIO.parse(handle, 'fasta')
         if entry.id == transcript),
        None,
    )
if record is None:
    raise Exception("Failed to find %s" % transcript)
# The description must end with two ", "-separated fields; the last of
# these has to be identical to the target given on the command line.
description, variant, genetype = record.description.rsplit(", ", 2)
assert genetype == target
# The last word of the remaining description is the gene name enclosed
# in parentheses.
word = description.split()[-1]
assert word.startswith("(") and word.endswith(")")
gene = word[1:-1]


# Find the alignment of the requested transcript in the PSL file of this
# target. This sets `length` to the size of the transcript (the query of the
# alignment) and `boundaries` to the query start positions of all alignment
# blocks except the first one.
directory = "/osc-fs_home/mdehoon/Data/CASPARs/Filters"
filename = "%s.psl" % target
path = os.path.join(directory, filename)
print("Reading", path)
# The with statement closes the file also if a check fails or the transcript
# is not found and an exception is raised.
with open(path) as handle:
    for line in handle:
        words = line.split()
        assert len(words) == 21
        qName = words[9]
        qSize = int(words[10])
        if qName == transcript:
            # drop the last item, which is empty if the list ends with a comma
            qStarts = [int(qStart) for qStart in words[19].split(",")[:-1]]
            strand = words[8]
            qStart = qStarts.pop(0)
            if strand == '+':
                assert qStart == 0
            # for strand == '-', qStart corresponds to the start of the poly(A) tail
            boundaries = array(qStarts)
            length = qSize
            break
    else:
        raise Exception("Failed to find %s" % transcript)

# Top-level data directory; read_data() looks for the alignment files in its
# MiSeq/PSL subdirectory.
directory = "/osc-fs_home/mdehoon/Data/CASPARs/"


def read_data():
    """Count alignment end positions on the transcript for each library.

    Scans the MiSeq/PSL subdirectory of `directory` for files with a name of
    the form <library>.<target>.psl.gz, keeping only the files for the
    current `target` and for libraries whose name starts with "t" (the time
    course samples). For each such file, every alignment pair returned by
    parse_alignments() whose target name equals `transcript` adds one count
    at index tEnd-1 of that library's profile.

    Returns:
        A dictionary mapping each library name to an array of `length`
        counts.
    """
    psl_directory = os.path.join(directory, "MiSeq", "PSL")
    profiles = {}
    for name in os.listdir(psl_directory):
        if not name.endswith(".psl.gz"):
            continue
        parts = name.split(".")
        library = parts[0]
        # keep only files for this target from time course samples
        if parts[1] != target or not library.startswith("t"):
            continue
        counts = zeros(length)
        for tName, tEnd in parse_alignments(os.path.join(psl_directory, name)):
            if tName == transcript:
                counts[tEnd-1] += 1
        profiles[library] = counts
    return profiles

# Plot the profile of each library as a semi-transparent black curve in a
# single small figure.
figure(figsize=(2.5,1.75))

profiles = read_data()
for profile in profiles.values():
    plot(profile, color='black', alpha=0.25)

# Draw a dashed blue vertical line from 0 to the current upper y limit for
# each boundary. The line is shifted by half a position so that it falls
# between index boundary-1 and index boundary.
ymin, ymax = ylim()
for boundary in boundaries:
    x = boundary-0.5
    plot([x, x], [0, ymax], 'b--', linewidth=1)

# Show the full transcript, and put back the y limits that were recorded
# before the boundary lines were drawn.
xlim(0,length)
ylim(ymin, ymax)

xticks(fontsize=8)
yticks(fontsize=8)

label = "Position of the 3' end with\nrespect to the %s\n(%s) mRNA 5' end [nucleotides]" % (transcript, gene)
xlabel(label, fontsize=8)
ylabel("Frequency", fontsize=8)

subplots_adjust(bottom=0.4, top=0.9, left=0.20, right=0.90)

# Save the same figure first as SVG and then as PNG.
for template in ("figure_termination_site_%s_%s_timecourse.svg",
                 "figure_termination_site_%s_%s_timecourse.png"):
    filename = template % (target, gene)
    print("Saving figure to %s" % filename)
    savefig(filename)
